// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "remove_packing_ops.hpp"

#include "helper_ops/packed_sequence.hpp"
#include "openvino/core/graph_util.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/gru_sequence.hpp"
#include "openvino/op/lstm_sequence.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/rnn_sequence.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

using namespace ov::pass;
using namespace ov::op;

namespace {
// Returns true if `node` is one of the recurrent sequence ops (opset-5
// LSTM/RNN/GRU) that the packing ops may be moved through.
// The shared_ptr is taken by const reference: copying it by value would
// incur an atomic refcount increment/decrement on every call for no benefit.
bool is_rnn(const std::shared_ptr<Node>& node) {
    return as_type_ptr<v5::LSTMSequence>(node) || as_type_ptr<v5::RNNSequence>(node) ||
           as_type_ptr<v5::GRUSequence>(node);
}
}  // namespace

// Moves a PackPadded helper op from before an RNN (LSTM/RNN/GRU sequence) to
// after it, so that the pack/pad pair can later be eliminated by
// RemovePackingOps. The recognized subgraph is:
//   PackPadded -> [Transpose] -> RNN -> (Transpose -> Reshape | Squeeze)
MovePackThroughLstm::MovePackThroughLstm() {
    auto pack_op = pattern::wrap_type<PackPadded>();

    ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
        auto pack = m.get_match_root();

        // The packed data output must feed exactly one consumer.
        auto targets = pack->output(0).get_target_inputs();
        if (targets.size() != 1)
            return false;
        auto rnn = targets.begin()->get_node()->shared_from_this();
        // Input to rnn may be transposed, skipping Transpose
        if (as_type_ptr<v1::Transpose>(rnn)) {
            auto transpose_targets = rnn->output(0).get_target_inputs();
            // Guard: dereferencing begin() of an empty consumer set is UB.
            if (transpose_targets.empty())
                return false;
            rnn = transpose_targets.begin()->get_node()->shared_from_this();
        }
        if (!is_rnn(rnn))
            return false;
        targets = rnn->output(0).get_target_inputs();
        if (targets.size() != 1)
            return false;

        // The rnn is followed by a transpose and a reshape (if bidirectional), or by a squeeze (if unidirectional).
        auto next = targets.begin()->get_node()->shared_from_this();
        if (as_type_ptr<v1::Transpose>(next)) {
            auto transpose_targets = next->output(0).get_target_inputs();
            // Guard: same empty-consumer-set UB as above.
            if (transpose_targets.empty())
                return false;
            next = transpose_targets.begin()->get_node()->shared_from_this();
            if (!as_type_ptr<v1::Reshape>(next)) {
                return false;
            }
        } else if (!as_type_ptr<v0::Squeeze>(next)) {
            return false;
        }

        // remove PackPadded from in front of the RNN
        pack->output(0).replace(pack->input_value(0));

        auto batch_sizes = pack->output(1);
        for (auto node_input : batch_sizes.get_target_inputs()) {
            auto user = node_input.get_node()->shared_from_this();
            // Make calculation of max_batch_size not depend on batch_sizes.
            // This looks for a pattern generated by code such as
            // https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815.
            //
            // Replace Gather[axis=0](batch_sizes, 0)
            // with    Gather[axis=0](ShapeOf(rnn_input), 0)
            if (const auto gather = as_type_ptr<v8::Gather>(user)) {
                if (gather->get_axis() != 0)
                    continue;
                auto rnn_shape = std::make_shared<v3::ShapeOf>(rnn->input_value(0), element::i32);
                auto indx_1 = v0::Constant::create(element::i32, Shape{}, {0});
                auto new_gather = std::make_shared<v8::Gather>(rnn_shape, indx_1, gather->input_value(2));
                copy_runtime_info_and_name(gather, {new_gather, rnn_shape, indx_1});
                replace_node(gather, new_gather);
            } else if (user == rnn) {
                // The RNN itself consumes batch_sizes as its sequence lengths:
                // feed it the original (unpacked) lengths instead.
                node_input.replace_source_output(pack->input_value(1));
            }
        }
        // and insert new PackPadded after the RNN.
        // Snapshot the consumer set BEFORE creating the new node, otherwise the
        // new PackPadded would appear in it and we would rewire it onto itself.
        auto next_target_inputs = next->output(0).get_target_inputs();
        auto newPackPadded = std::make_shared<PackPadded>(next->output(0), pack->input_value(1));
        register_new_node(newPackPadded);

        // make things consume from the new PackPadded
        for (auto& input : next_target_inputs)
            input.replace_source_output(newPackPadded->output(0));
        pack->output(1).replace(newPackPadded->output(1));

        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(pack_op, "ov::frontend::pytorch::pass::MovePackThroughLstm");
    this->register_matcher(m, callback);
}

// Removes a PackPadded/PadPacked pair (optionally separated by a Transpose)
// by routing both the data and the lengths outputs straight through, making
// the packing ops dead. Intended to run after MovePackThroughLstm has moved
// the PackPadded to sit directly in front of the PadPacked.
RemovePackingOps::RemovePackingOps() {
    auto unpack_op = pattern::wrap_type<PadPacked>();

    ov::matcher_pass_callback callback = [](pattern::Matcher& m) {
        const auto& unpack = m.get_match_root();
        auto producer = unpack->input_value(0).get_node_shared_ptr();
        if (!producer)
            return false;
        // PadPacked may consume PackPadded directly or through a Transpose.
        if (as_type_ptr<v1::Transpose>(producer))
            producer = producer->input_value(0).get_node_shared_ptr();
        // The producer must actually be a PackPadded on BOTH paths; without
        // this check an arbitrary node would reach the output(1) rewiring
        // below (out-of-range output access on single-output nodes).
        const auto pack_node = ov::as_type_ptr<PackPadded>(producer);
        if (!pack_node)
            return false;

        // Bypass both ops: output(0) is the data tensor, output(1) the lengths.
        pack_node->output(0).replace(pack_node->input_value(0));
        pack_node->output(1).replace(pack_node->input_value(1));
        unpack->output(0).replace(unpack->input_value(0));
        unpack->output(1).replace(unpack->input_value(1));

        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(unpack_op, "ov::frontend::pytorch::pass::RemovePackingOps");
    this->register_matcher(m, callback);
}

}  // namespace pass
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
